import ast
import json
import os

import pyAgrum as gum
import pyAgrum.lib.notebook as gnb
import numpy as np

from Causal_Partial_Mnist.Find_CF_Synthetic_Distribution_Mnist import get_bayesian_network, get_bn


def _cf_cache_path(Exp, expr, cf_intervene, cf_evidence):
    """Return the on-disk cache path for one counterfactual query.

    The name is ``Exp.Cf_SCMs + expr``, followed by the ``str`` of every value
    in ``cf_intervene`` and then in ``cf_evidence`` (dict iteration order),
    followed by ``.txt``.
    """
    # NOTE(review): values are concatenated without separators, so e.g. value
    # sequences (1, 11) and (11, 1) map to the same file. Kept as-is so that
    # existing cache files keep their names.
    intvals = "".join(str(x) for x in cf_intervene.values())
    cfvals = "".join(str(x) for x in cf_evidence.values())
    return Exp.Cf_SCMs + expr + intvals + cfvals + ".txt"


def _load_cf_dist(file_name):
    """Read a distribution written by ``_save_cf_dist`` and restore its tuple keys.

    Raises:
        ValueError / SyntaxError: if a stored key is not a Python literal.
    """
    with open(file_name, encoding="utf-8") as f:
        saved = json.load(f)
    # Keys were stored as ``str(tuple)``. ``ast.literal_eval`` parses them back
    # without executing code; plain ``eval`` would run any expression a
    # tampered cache file contained.
    return {ast.literal_eval(key): prob for key, prob in saved.items()}


def _save_cf_dist(file_name, dist, expr):
    """Write ``dist`` to ``file_name`` as JSON, stringifying its tuple keys."""
    save_dist_dict = {str(key): prob for key, prob in dist.items()}
    print(f"Saving {expr} at {file_name}")
    with open(file_name, "w", encoding="utf-8") as fp:
        json.dump(save_dist_dict, fp)


def _build_twin_network(Exp, INSTANCE, cf_intervene):
    """Build the twin-network BayesNet used for counterfactual queries.

    Nodes and arcs come from ``Exp.Twin_Network`` (label -> parent labels) and
    ``Exp.cf_exogenous`` (label -> exogenous latent variable name). CPTs are
    filled from ``INSTANCE["noise_dist"]`` and ``INSTANCE["cpt"]``. Labels in
    ``cf_intervene`` get a point-mass CPT on the intervened state index.
    """
    bnc = gum.BayesNet("counterfactual SCM")

    # Each node's incoming arcs are added right after the node itself, so its
    # parents must already have been added (i.e. Twin_Network must list parents first).
    for label in Exp.Twin_Network:
        if label in INSTANCE["noise_dist"]:
            n_states = Exp.latent_state
        else:
            n_states = Exp.label_dim[label]["feature"]
        bnc.add(label, n_states)
        for parent in Exp.Twin_Network[label]:
            bnc.addArc(parent, label)

    # Exogenous latents: each latent node is created once, even when several
    # labels point to the same one.
    added_latents = set()
    for label in Exp.cf_exogenous:
        latent = Exp.cf_exogenous[label]
        if latent not in added_latents:
            bnc.add(latent, Exp.noise_states)
            added_latents.add(latent)
        bnc.addArc(latent, label)

    # Priors of the noise variables.
    for noise, dist in INSTANCE["noise_dist"].items():
        # NOTE(review): "nX1" is intentionally not filled here; the reason is
        # not visible in this module — confirm.
        if noise == "nX1":
            continue
        bnc.cpt(noise).fillWith(dist)

    # The primed copies X1p/X2p reuse the CPT entries of X1/X2.
    twin_to_factual = {"X1p": "X1", "X2p": "X2"}
    for label in Exp.cflabel_names:
        if label in INSTANCE["noise_dist"]:  # common confounders
            bnc.cpt(label).fillWith(INSTANCE["noise_dist"][label])
            continue

        # Looked up only for non-noise labels, so confounders need no "cpt" entry.
        probs = INSTANCE["cpt"][twin_to_factual.get(label, label)]["feature"]
        cpt = bnc.cpt(label)
        # ``probs`` is a flat list consumed in the potential's loopIn() order.
        for i, inst in enumerate(cpt.loopIn()):
            cpt.set(inst, probs[i])

        if label in cf_intervene:
            # Intervention: overwrite the CPT with a point mass on the chosen state.
            point_mass = [0] * Exp.label_dim[label]["feature"]
            point_mass[cf_intervene[label]] = 1
            cpt.fillWith(point_mass)

    return bnc


def get_cf_dist(Exp, targetVars, cf_intervene, cf_evidence, expr, load_dist=False):
    """Compute (or load from cache) a counterfactual posterior on the twin network.

    The twin network is built from ``Exp`` and the SCM JSON at ``Exp.SCM_PATH``.
    It is copied and passed through ``get_bn`` with ``cf_intervene``. The result
    is queried with ``gum.getPosterior`` under evidence ``cf_evidence``.
    Results are cached as JSON under ``Exp.Cf_SCMs``. If the cache file already
    exists, its contents are returned without recomputation.

    Args:
        Exp: experiment configuration object. The attributes read here (and in
            the helpers) are ``Cf_SCMs``, ``SCM_PATH``, ``Twin_Network``,
            ``cf_exogenous``, ``cflabel_names``, ``label_dim``,
            ``latent_state``, ``noise_states`` and ``Complete_DAG``.
        targetVars: iterable of target variable names. Only the LAST one is
            queried and returned; earlier names were always overwritten.
        cf_intervene: mapping of label -> state index to intervene on.
        cf_evidence: mapping of label -> observed state, used as evidence.
        expr: tag used in the cache file name and the "Saving" message.
        load_dist: unused. The cache is read whenever the file exists.

    Returns:
        dict mapping a tuple of state values (ordered by ``top_sort_dict`` over
        ``Exp.Complete_DAG`` keys) to its posterior probability.

    Raises:
        ValueError: if ``targetVars`` is empty and no cache file exists, or if
            a cached key is not a Python literal.
    """
    file_name = _cf_cache_path(Exp, expr, cf_intervene, cf_evidence)
    if os.path.exists(file_name):
        return _load_cf_dist(file_name)

    targets = list(targetVars)
    if not targets:
        raise ValueError("targetVars must name at least one target variable")

    with open(Exp.SCM_PATH, encoding="utf-8") as f:
        INSTANCE = json.load(f)

    bnc = _build_twin_network(Exp, INSTANCE, cf_intervene)

    # Counterfactual inference, e.g. P(Y | X1'=1, X2'=1, do(X1=0, X2=0)):
    # apply the interventions to a copy of the twin network via get_bn.
    true_bn = get_bn(Exp, gum.BayesNet(bnc), cf_intervene)

    # Only the last target's distribution was ever returned, so only it is queried.
    posterior = gum.getPosterior(true_bn, evs=cf_evidence, target=targets[-1])
    target_dist = {}
    for inst in posterior.loopIn():
        # NOTE(review): top_sort_dict is not among this module's visible
        # imports — confirm it is in scope.
        ordered = top_sort_dict(inst.todict(), Exp.Complete_DAG.keys())
        target_dist[tuple(ordered.values())] = posterior[inst]

    _save_cf_dist(file_name, target_dist, expr)
    return target_dist





def _joint_impact(Exp, bn, targets):
    """Run LazyPropagation on ``bn`` and return ``evidenceJointImpact(targets, [])``.

    A joint target over all ``Exp.label_names`` is declared before inference,
    matching the setup used for every query in ``get_cf_from_intervs``.
    """
    ie = gum.LazyPropagation(bn)
    ie.addJointTarget(set(Exp.label_names))
    ie.makeInference()
    return ie.evidenceJointImpact(targets, [])


def get_cf_from_intervs(Exp, targetVars, cf_intervene, cf_evidence, expr):
    """Estimate a counterfactual-style conditional from interventional queries.

    Computes the mixture

        M(Y, X2) = sum_w P(Y, X2 | do(X1=x1, W=w)) * P(W=w | do(cf_intervene))

    with ``x1 = cf_intervene["X1"]`` and ``Y = targetVars[0]``. It then prints
    ``M / M_X2``, where ``M_X2`` is M with Y summed out and normalized. When M
    is a proper joint, this is Y conditioned on X2. The original notebook
    captions describe it as an estimate of e.g.
    P(Ydigit | X1=1, X2=1, do(X1=0, X2=0)).

    Args:
        Exp: experiment configuration. ``label_names`` and ``label_dim["W"]``
            are read here, and ``Exp`` is passed to
            ``get_bayesian_network``/``get_bn``.
        targetVars: target names; only the first is used.
        cf_intervene: interventions used for P(W | do(...)); must contain "X1".
        cf_evidence: unused (kept for interface compatibility).
        expr: unused (kept for interface compatibility).

    Returns:
        The printed pyAgrum Potential ``M / M_X2``. Previously nothing was
        returned.
    """
    Y_var = targetVars[0]

    true_bn, _ = get_bayesian_network(Exp, {}, load_scm=1)
    true_bn = true_bn["feature"]

    # P(W | do(cf_intervene)), e.g. P(W | do(X1=0, X2=0)).
    p_w = _joint_impact(Exp, get_bn(Exp, gum.BayesNet(true_bn), cf_intervene), ["W"])

    mixture = None
    for w in range(Exp.label_dim["W"]["feature"]):
        bn_w = get_bn(Exp, gum.BayesNet(true_bn), {"X1": cf_intervene["X1"], "W": w})
        term = _joint_impact(Exp, bn_w, [Y_var, "X2"]) * p_w[w]
        if mixture is None:
            mixture = term
            continue
        # Each term comes from a separate BN copy, so values are added by
        # position (mixture's loopIn order vs. the term's flattened array).
        # This assumes both potentials share the same variable order.
        term_values = term.toarray().ravel()
        for pos, inst in enumerate(mixture.loopIn()):
            mixture.set(inst, mixture.get(inst) + term_values[pos])

    p_x2 = mixture.margSumOut(Y_var).normalize()
    result = mixture / p_x2
    print(result)
    return result

